-
Notifications
You must be signed in to change notification settings - Fork 180
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: torch.compile and custom_op support #554
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This was referenced Oct 24, 2024
abcdabcd987
force-pushed
the
lequn/1023-torchlib
branch
from
October 25, 2024 02:25
a870e3e
to
f28464d
Compare
yzh119
added a commit
that referenced
this pull request
Oct 25, 2024
The block sparse attention unittests failed as noted in #554, this PR fixes the issue.
yzh119
approved these changes
Oct 25, 2024
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for the huge improvement, I have left some tiny suggestions.
yzh119
added a commit
that referenced
this pull request
Oct 26, 2024
#554 didn't update the `batch_prefill.cu` (which was used in AOT mode) according to the API change. This PR fixes the issue. cc @abcdabcd987
yzh119
pushed a commit
that referenced
this pull request
Oct 26, 2024
Fix bugs introduced in #554 1. Function signature change for `chain_speculative_sampling()` pybind in aot. 2. `packbits()` uses a str default value, which is not supported by PyTorch 2.4. This PR added a workaround. 3. For Pytorch < 2.4, the two decorators (`register_custom_op()` and `register_fake_op()`) should return identity function instead of `None`.
yzh119
added a commit
that referenced
this pull request
Dec 17, 2024
🤖 I have created a release *beep* *boop* --- ## [0.2.0](v0.1.6...v0.2.0) (2024-12-17) [Release Blog](https://flashinfer.ai/2024/12/16/flashinfer-v02-release.html). ### Features * add `rotary_dim` argument to rope APIs for partial apply rope ([#599](#599)) ([eb9bc71](eb9bc71)) * add a `use_softmax` field in variant class ([#533](#533)) ([d81af97](d81af97)) * add an option `non_blocking` to plan function ([#622](#622)) ([560af6f](560af6f)) * add gemma_rmsnorm and gemma_fused_add_rmsnorm ([#477](#477)) ([1a6b17e](1a6b17e)) * add group size 3 to GQA decode dispatch ([#558](#558)) ([6227562](6227562)) * add JIT compilation support for FA3 templates ([#672](#672)) ([d4e8d79](d4e8d79)) * allow the cascade kernels to be executed using varying sequence lenghts ([#627](#627)) ([92ac440](92ac440)) * CUDAGraph compatibility of multi-level cascade inference APIs ([#586](#586)) ([2332e8a](2332e8a)) * fix the maximal grid dimension in prefill planning with CUDA graphs ([#639](#639)) ([86ca89a](86ca89a)) * improve the precision of the FusedAddRMSNormKernel function ([#587](#587)) ([c7dc921](c7dc921)) * JIT compilation ([#507](#507)) ([3613a5b](3613a5b)) * modify group-gemm stage number ([#497](#497)) ([52dab1d](52dab1d)) * non-contiguous query with paged kv cache ([#553](#553)) ([89f2c4a](89f2c4a)) * pass a dynamic token count to the cascade kernels ([#635](#635)) ([5fe9f7d](5fe9f7d)) * simplify prefill JIT compilation ([#605](#605)) ([fe4f898](fe4f898)) * specify gemm backend ([#648](#648)) ([0cc1a51](0cc1a51)) * support cached cos/sin in rope APIs ([#585](#585)) ([83e541d](83e541d)) * support huggingface transformer style rope interface ([#568](#568)) ([4f40420](4f40420)) * support sm90 cutlass group gemm ([#509](#509)) ([794bdda](794bdda)) * torch custom_op fix for rope ([#569](#569)) ([3e104bc](3e104bc)) * torch custom_op support: norm ([#552](#552)) ([f6e0010](f6e0010)) * torch.compile and custom_op support ([#554](#554)) ([9bf916f](9bf916f)) * warmup for jit kernel tests ([#629](#629)) ([8f5f349](8f5f349)) ### Bug Fixes * AOT compiler flags on non-sm90 ([#522](#522)) ([0aa4726](0aa4726)) * batch decode kernel redundant store output to gmem ([#505](#505)) ([90e42a7](90e42a7)) * compatible with torch 2.2 ([#478](#478)) ([ac41d1b](ac41d1b)) * #452 ([b53a46f](b53a46f)) * remove redundant load ([#495](#495)) ([2de16b0](2de16b0)) * update bmm fp8 test ([#487](#487)) ([45eac04](45eac04)) ### Performance Improvements * accelerate JIT compilation speed ([#618](#618)) ([eaf73fd](eaf73fd)) * Dense and sparse customizable flashattention-3 template ([#667](#667)) ([51236c9](51236c9)) * fix prefill kernel performance degradation (step 1) ([#602](#602)) ([595cf60](595cf60)) * fix the performance issue of `append_paged_kv_cache` ([#588](#588)) ([e15f7c9](e15f7c9)) * improve parallelism in RoPE with pos_ids ([#609](#609)) ([ff05155](ff05155)) * improve plan performance by using non-blocking memcpy ([#547](#547)) ([41ebe6d](41ebe6d)) * reduce the read and write of shared memory in the FusedAddRMSNormKernel ([#592](#592)) ([2043ca2](2043ca2)) * reduce total_num_tiles_q by one ([#644](#644)) ([553ace5](553ace5)) * remove unnecessary contiguous operation in block sparse attention ([#561](#561)) ([7a7ad46](7a7ad46)) * speedup jit compilation of prefill attention kernels ([#632](#632)) ([a059586](a059586)) * use cuda-core implemention for io-bound block-sparse attention ([#560](#560)) ([3fbf028](3fbf028)) --- This PR was generated with [Release Please](https://github.com/googleapis/release-please). See [documentation](https://github.com/googleapis/release-please#release-please). --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Zihao Ye <expye@outlook.com>
yzh119
added a commit
that referenced
this pull request
Dec 30, 2024
We observe performance degradation for small operations in flashinfer v0.2 because of the overhead of `torch.library.custom_op` introduced in #554. This PR disables torch custom operator registrations first, we can add them back with lightweight registration later: https://github.com/vllm-project/vllm/blob/36e76700453924c8d421db99af70a88a1df835cd/vllm/utils.py#L1660-L1674 cc @zhyncs @abcdabcd987 @youkaichao
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Follow up of #552. This PR adds torch library annotation to all FlashInfer kernels so that torch.compile can recognize the kernels. Most changes are tedious.
I manually ran subsets of pytest test cases when I made these changes, but since there are too many of them and also some of them didn't pass even before I made the change, I cannot guarantee it's all working. To run tests with torch.compile, pass
FLASHINFER_TEST_TORCH_COMPILE=1
env.Notable changes:
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
depending onreturn_lse
. This causes trouble fortorch.compile
. I changed the pybind interface to accept amaybe_lse: Optional[torch.Tensor]
and only return one tensor. The allocation of the lse tensor is moved to Python side. The Python API does not change.chain_speculative_sampling
pybind: Move the allocation ofaccepted
andemitted
from C++ to Python. This is becausetorch.compile
doesn't like returning input tensor as output tensor. The Python API does not change.Piggyback changes:
BatchPrefillWithRaggedKVCacheWrapper.plan
: Bugfix qo_indptr not on CPUmerge_state
: Fix typo in docsrun_return_lse(...)
torun(..., return_lse=True)
because torch.compile does not recognizefunctools.partial
.flashinfer.xxx()
toflashinfer.<module>.xxx()
so that the monkeypatch works.Unsupported for torch.compile:
flashinfer.quantization.segment_packbits
: Because it's data dependent.Untouched:
sparse.py
: Tests didn't pass beforehand, so I skiped this. Also, it doesn't seem like need custom_op annotations, as it does not have CUDA kernels.Failed test cases:
test_batch_decode_with_paged_kv_cache[False-kv_dtype0-q_dtype0-True-0.0-NONE-NHD-128-4-4-1-54-12]